{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import onmt\n",
    "import onmt.inputters\n",
    "import onmt.modules\n",
    "import onmt.utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We begin by loading the vocabulary for the model of interest. This lets us check the vocabulary size and look up the special ids used for padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the serialized vocabulary and index it by side (src / tgt).\n",
    "# NOTE: torch.load unpickles the file, so only load vocab files you trust.\n",
    "vocab = dict(torch.load(\"../../data/data.vocab.pt\"))\n",
    "\n",
    "# Look up the id of the padding token on each side through the\n",
    "# vocabulary's `stoi` table.\n",
    "src_padding, tgt_padding = (\n",
    "    vocab[side].stoi[onmt.inputters.PAD_WORD] for side in (\"src\", \"tgt\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we specify the core model itself. Here we will build a small model with an encoder and an attention-based, input-feeding decoder. Both will be RNNs, and the encoder will be bidirectional."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb_size = 10\n",
    "rnn_size = 6\n",
    "\n",
    "# Specify the core model: a single-layer bidirectional LSTM encoder and a\n",
    "# single-layer input-feeding LSTM decoder, each with its own embeddings.\n",
    "encoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab[\"src\"]),\n",
    "                                             word_padding_idx=src_padding)\n",
    "\n",
    "encoder = onmt.encoders.RNNEncoder(hidden_size=rnn_size, num_layers=1,\n",
    "                                   rnn_type=\"LSTM\", bidirectional=True,\n",
    "                                   embeddings=encoder_embeddings)\n",
    "\n",
    "decoder_embeddings = onmt.modules.Embeddings(emb_size, len(vocab[\"tgt\"]),\n",
    "                                             word_padding_idx=tgt_padding)\n",
    "decoder = onmt.decoders.decoder.InputFeedRNNDecoder(\n",
    "    hidden_size=rnn_size, num_layers=1,\n",
    "    bidirectional_encoder=True,\n",
    "    rnn_type=\"LSTM\", embeddings=decoder_embeddings)\n",
    "model = onmt.models.model.NMTModel(encoder, decoder)\n",
    "\n",
    "# Specify the tgt word generator and loss computation module.\n",
    "# LogSoftmax gets an explicit dim: the scores must be normalized over the\n",
    "# vocabulary axis (the last one, produced by the Linear layer). Leaving dim\n",
    "# unset is deprecated in PyTorch and picks the axis from the input's rank.\n",
    "model.generator = nn.Sequential(\n",
    "    nn.Linear(rnn_size, len(vocab[\"tgt\"])),\n",
    "    nn.LogSoftmax(dim=-1))\n",
    "loss = onmt.utils.loss.NMTLossCompute(model.generator, vocab[\"tgt\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we set up the optimizer. This could be a core `torch.optim` class, or our wrapper, which handles learning rate updates and gradient normalization automatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the onmt optimizer wrapper around SGD (lr=1, max_grad_norm=2).\n",
    "optim = onmt.utils.optimizers.Optimizer(\n",
    "    method=\"sgd\",\n",
    "    lr=1,\n",
    "    max_grad_norm=2,\n",
    ")\n",
    "\n",
    "# Hand the model's named parameters to the optimizer.\n",
    "optim.set_parameters(model.named_parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we load the data from disk. Currently we also need to call a function to load the fields into the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the preprocessed training and validation shards from disk.\n",
    "data, valid_data = (\n",
    "    torch.load(shard_path)\n",
    "    for shard_path in (\"../../data/data.train.1.pt\",\n",
    "                       \"../../data/data.valid.1.pt\")\n",
    ")\n",
    "\n",
    "# The fields have to be loaded into each dataset from the vocabulary.\n",
    "data.load_fields(vocab)\n",
    "valid_data.load_fields(vocab)\n",
    "\n",
    "# Keep only the first 100 training examples (presumably to keep the demo fast).\n",
    "data.examples = data.examples[:100]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To iterate through the data itself we use a torchtext iterator class. We specify one each for the training and validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch iterators over the two datasets, 10 examples per batch.\n",
    "# NOTE(review): device=-1 presumably selects the CPU in the torchtext\n",
    "# version this notebook was written for -- confirm for your install.\n",
    "train_iter = onmt.inputters.OrderedIterator(\n",
    "    dataset=data,\n",
    "    batch_size=10,\n",
    "    device=-1,\n",
    "    repeat=False,\n",
    ")\n",
    "valid_iter = onmt.inputters.OrderedIterator(\n",
    "    dataset=valid_data,\n",
    "    batch_size=10,\n",
    "    device=-1,\n",
    "    train=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we train."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch  0,     0/   10; acc:   0.00; ppl: 1225.23; 1320 src tok/s; 1320 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     1/   10; acc:   9.50; ppl: 996.33; 1188 src tok/s; 1194 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     2/   10; acc:  16.51; ppl: 694.48; 1265 src tok/s; 1267 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     3/   10; acc:  20.49; ppl: 470.39; 1459 src tok/s; 1420 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     4/   10; acc:  22.68; ppl: 387.03; 1511 src tok/s; 1462 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     5/   10; acc:  24.58; ppl: 345.44; 1625 src tok/s; 1509 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     6/   10; acc:  25.37; ppl: 314.39; 1586 src tok/s; 1493 tgt tok/s; 1514090454 s elapsed\n",
      "Epoch  0,     7/   10; acc:  26.14; ppl: 291.15; 1593 src tok/s; 1520 tgt tok/s; 1514090455 s elapsed\n",
      "Epoch  0,     8/   10; acc:  26.32; ppl: 274.79; 1606 src tok/s; 1545 tgt tok/s; 1514090455 s elapsed\n",
      "Epoch  0,     9/   10; acc:  26.83; ppl: 247.32; 1669 src tok/s; 1614 tgt tok/s; 1514090455 s elapsed\n",
      "Validation\n",
      "Epoch  0,    11/   10; acc:  13.41; ppl: 111.94;   0 src tok/s; 7329 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     0/   10; acc:   6.59; ppl: 147.05; 1849 src tok/s; 1743 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     1/   10; acc:  22.10; ppl: 130.66; 2002 src tok/s; 1957 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     2/   10; acc:  20.16; ppl: 122.49; 1748 src tok/s; 1760 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     3/   10; acc:  23.52; ppl: 117.41; 1690 src tok/s; 1698 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     4/   10; acc:  24.16; ppl: 119.42; 1647 src tok/s; 1662 tgt tok/s; 1514090464 s elapsed\n",
      "Epoch  1,     5/   10; acc:  25.44; ppl: 115.31; 1775 src tok/s; 1709 tgt tok/s; 1514090465 s elapsed\n",
      "Epoch  1,     6/   10; acc:  24.05; ppl: 115.11; 1780 src tok/s; 1718 tgt tok/s; 1514090465 s elapsed\n",
      "Epoch  1,     7/   10; acc:  25.32; ppl: 109.59; 1799 src tok/s; 1765 tgt tok/s; 1514090465 s elapsed\n",
      "Epoch  1,     8/   10; acc:  25.14; ppl: 108.16; 1771 src tok/s; 1734 tgt tok/s; 1514090465 s elapsed\n",
      "Epoch  1,     9/   10; acc:  25.58; ppl: 107.13; 1817 src tok/s; 1757 tgt tok/s; 1514090465 s elapsed\n",
      "Validation\n",
      "Epoch  1,    11/   10; acc:  19.58; ppl:  88.09;   0 src tok/s; 7371 tgt tok/s; 1514090474 s elapsed\n"
     ]
    }
   ],
   "source": [
    "trainer = onmt.Trainer(model, loss, loss, optim)\n",
    "\n",
    "def report_func(*args):\n",
    "    stats = args[-1]\n",
    "    stats.output(args[0], args[1], 10, 0)\n",
    "    return stats\n",
    "\n",
    "for epoch in range(2):\n",
    "    trainer.train(epoch, report_func)\n",
    "    val_stats = trainer.validate()\n",
    "\n",
    "    print(\"Validation\")\n",
    "    val_stats.output(epoch, 11, 10, 0)\n",
    "    trainer.epoch_step(val_stats.ppl(), epoch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use the model, we need to load the translation functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import onmt.translate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PRED SCORE: -4.0690\n",
      "\n",
      "SENT 0: ('The', 'competitors', 'have', 'other', 'advantages', ',', 'too', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.2736\n",
      "\n",
      "SENT 0: ('The', 'company', '&apos;s', 'durability', 'goes', 'back', 'to', 'its', 'first', 'boss', ',', 'a', 'visionary', ',', 'Thomas', 'J.', 'Watson', 'Sr.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.0144\n",
      "\n",
      "SENT 0: ('&quot;', 'From', 'what', 'we', 'know', 'today', ',', 'you', 'have', 'to', 'ask', 'how', 'I', 'could', 'be', 'so', 'wrong', '.', '&quot;')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.1361\n",
      "\n",
      "SENT 0: ('Boeing', 'Co', 'shares', 'rose', '1.5%', 'to', '$', '67.94', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.1382\n",
      "\n",
      "SENT 0: ('Some', 'did', 'not', 'believe', 'him', ',', 'they', 'said', 'that', 'he', 'got', 'dizzy', 'even', 'in', 'the', 'truck', ',', 'but', 'always', 'wanted', 'to', 'fulfill', 'his', 'dream', ',', 'that', 'of', 'becoming', 'a', 'pilot', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -3.8881\n",
      "\n",
      "SENT 0: ('In', 'your', 'opinion', ',', 'the', 'council', 'should', 'ensure', 'that', 'the', 'band', 'immediately', 'above', 'the', 'Ronda', 'de', 'Dalt', 'should', 'provide', 'in', 'its', 'entirety', ',', 'an', 'area', 'of', 'equipment', 'to', 'conduct', 'a', 'smooth', 'transition', 'between', 'the', 'city', 'and', 'the', 'green', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.0778\n",
      "\n",
      "SENT 0: ('The', 'clerk', 'of', 'the', 'court', ',', 'Jorge', 'Yanez', ',', 'went', 'to', 'the', 'jail', 'of', 'the', 'municipality', 'of', 'San', 'Nicolas', 'of', 'Garza', 'to', 'notify', 'Jonah', 'that', 'he', 'has', 'been', 'legally', 'pardoned', 'and', 'his', 'record', 'will', 'be', 'filed', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.2479\n",
      "\n",
      "SENT 0: ('&quot;', 'In', 'a', 'research', 'it', 'is', 'reported', 'that', 'there', 'are', 'no', 'parts', 'or', 'components', 'of', 'the', 'ship', 'in', 'another', 'place', ',', 'the', 'impact', 'is', 'presented', 'in', 'a', 'structural', 'way', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -3.8585\n",
      "\n",
      "SENT 0: ('On', 'the', 'asphalt', 'covering', ',', 'he', 'added', ',', 'is', 'placed', 'a', 'final', 'layer', 'called', 'rolling', 'covering', ',', 'which', 'is', 'made', '\\u200b', '\\u200b', 'of', 'a', 'fine', 'stone', 'material', ',', 'meaning', 'sand', 'also', 'dipped', 'into', 'the', 'asphalt', '.')\n",
      "PRED 0: .\n",
      "\n",
      "PRED SCORE: -4.2298\n",
      "\n",
      "SENT 0: ('This', 'is', '200', 'bar', 'on', 'leaving', 'and', '100', 'bar', 'on', 'arrival', '.')\n",
      "PRED 0: .\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/torch/tensor.py:297: UserWarning: other is not broadcastable to self, but they have the same number of elements.  Falling back to deprecated pointwise behavior.\n",
      "  return self.add_(other)\n"
     ]
    }
   ],
   "source": [
    "translator = onmt.translate.Translator(beam_size=10, fields=data.fields, model=model)\n",
    "builder = onmt.translate.TranslationBuilder(data=valid_data, fields=data.fields)\n",
    "\n",
    "valid_data.src_vocabs\n",
    "for batch in valid_iter:\n",
    "    trans_batch = translator.translate_batch(batch=batch, data=valid_data)\n",
    "    translations = builder.from_batch(trans_batch)\n",
    "    for trans in translations:\n",
    "        print(trans.log(0))\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
